'''
Created on 2012-8-5

@author: DXD.Spirits
'''


from collections import defaultdict

from containers import PriorityQueue
from shortcuts import log_info


def dijkstra(graph, origin, destin):
    '''
    Single-criterion shortest path from origin to destin (Dijkstra).

    @var graph: graph = {node1: edges1, node2: edges2, ...}
                edges = {neighbour1: distance1, neighbour2: distance2, ...}
                A neighbour that is not itself a key of graph is treated as
                a node without outgoing edges.
                Dijkstra's algorithm assumes non-negative distances.
    @return: (path, distance)
             path = [origin, node1, node2, ... , destin]
             When destin cannot be reached, returns ([], 0xFFFFFFF).
    '''
    # Finite "unreachable" sentinel; it is also returned as the distance when
    # no path exists, so it is kept as-is for callers. Any path whose length
    # reaches this value is treated as unreachable.
    # NOTE(review): martins() uses 0xFFFFFFFF (one more F) - confirm which is intended.
    inf = 0xFFFFFFF
    dist = dict.fromkeys(graph, inf)
    dist[origin] = 0
    prev = {origin: None}
    q = PriorityQueue()
    q.add(origin, 0)

    while True:
        u = q.pop()
        # pop() yields None once the queue is exhausted: destin is unreachable.
        if u is None or dist[u] == inf:
            return [], inf

        if u == destin:
            # Walk the predecessor chain back to the origin (whose prev is None).
            path = []
            while u is not None:
                path.append(u)
                u = prev[u]
            path.reverse()
            return path, dist[destin]

        # .items() works on Python 2 and 3 (iteritems is Python 2 only);
        # .get() tolerates sink nodes that only appear as neighbours.
        for v, c in graph.get(u, {}).items():
            alt = dist[u] + c
            if alt < dist.get(v, inf):
                dist[v] = alt
                q.add(v, alt)
                prev[v] = u


def martins(graph, origin, destin, costfunc=None, K=3, init=None, inf=None):
    '''
    Multi-criteria label-setting search (Martins' algorithm) from origin to destin.

    Every arc carries a vector of K costs. For each node the search keeps the
    labels (cost vectors) that no other label of that node dominates, and
    returns one path per non-dominated label that reaches destin.
    Label-setting correctness assumes non-negative arc costs.

    @var graph: graph = {node1: edges1, node2: edges2, ...}
                edges = {neighbour1: cost1, neighbour2: cost2, ...}
                A neighbour that is not a key of graph is treated as a node
                without outgoing arcs.
    @var costfunc: costfunc(label, arccost) -> new list of K values. It
                receives the full label (K values followed by p, h) and must
                return a fresh list, because the result is extended in place.
                Defaults to component-wise addition of the first K values.
    @var K: number of criteria to optimize.
    @var init: the K values of the origin label; defaults to [0] * K.
    @var inf: labels comparing >= inf (as Python lists, i.e. lexicographically)
              are discarded; defaults to [0xFFFFFFFF] * K.
    @return: (path_list, node_visited_sequence)
             path_list = [(cost, path), ...], one entry per permanent label of
             destin; cost is that label's K values and
             path = [origin, node1, ..., destin].
             node_visited_sequence is never filled and is always [].

    label = [value0, ..., value(K-1), p, h]
        p: previous node where this label is from (None for the origin label)
        h: index, in the permanent label list of p, of the label this one
           was extended from
    In the original use case: value = duration, walktime, transfers.
    '''
    # None sentinels instead of mutable list defaults, sized from K so that
    # K != 3 yields labels of a consistent length.
    if init is None:
        init = [0] * K
    if inf is None:
        inf = [0xFFFFFFFF] * K
    inf = list(inf)

    log_info("Martins start with %d performances to optimize, from %s to %s" % (K, origin, destin))

    def addl(labelval, arcval):
        # Default cost function: component-wise sum over the K criteria.
        return [labelval[i] + arcval[i] for i in range(K)]

    if costfunc is None:
        costfunc = addl

    def dominate(a, b):
        ''' Return True if a <= b on each of the first K criteria.
            Equality counts as domination: a label arriving later with the
            same values as an existing one is not interesting.
        '''
        return not any(a[i] > b[i] for i in range(K))

    def remove_tlabels(u, i):
        ''' Remove tlabels[u][i] by moving the last temporary label of u into
            slot i, then shrinking the list.
            Heap entries are keyed by (node, slot), so the moved label's heap
            entry is re-keyed from (u, end) to (u, i). Callers drop the (u, i)
            heap entry themselves beforehand (heap.remove(), or heap.pop(),
            which is assumed to remove the entry it returns).
        '''
        labels = tlabels[u]
        end = len(labels) - 1
        if end != i:
            lend = labels[end]
            heap.add((u, i), lend)
            heap.remove((u, end))
            labels[i] = lend
        labels.pop()

    def track_paths():
        ''' Build one (cost, path) pair per permanent label of destin by
            following the (p, h) back-pointers to the origin label. '''
        path_list = []
        for ldestin in plabels[destin]:
            path = [destin]
            label = ldestin
            while label[-2] is not None:
                u, i = label[-2:]
                path.append(u)
                label = plabels[u][i]
            path.reverse()
            path_list.append((ldestin[:K], path))
        return path_list

    # Kept for the (path_list, node_visited_sequence) return shape; nothing
    # is recorded into it.
    node_visited_sequence = []
    # Priority queue contents: key (v, j) with priority lv, where v is a node,
    # j is the slot of the label in tlabels[v], and lv is the label itself
    # (labels are ordered as Python lists, i.e. lexicographically).
    heap = PriorityQueue()
    # Permanent / temporary label lists per node; defaultdict tolerates nodes
    # that only appear as neighbours (or a destin absent from graph).
    plabels = defaultdict(list)
    tlabels = defaultdict(list)

    init_value = list(init) + [None, None]
    tlabels[origin].append(init_value)
    heap.add((origin, 0), init_value)

    while True:
        val = heap.pop()
        if val is None:
            log_info("Heap empty, Martins iteration terminates.")
            return track_paths(), node_visited_sequence

        u, i = val
        lu = tlabels[u][i]
        # The popped label leaves the temporary set whether or not it is kept.
        remove_tlabels(u, i)

        # NOTE(review): lexicographic list comparison, so in practice this
        # tests the first criterion first - confirm a per-component test
        # was not intended.
        if lu >= inf:
            continue

        plabels[u].append(lu)
        h = len(plabels[u]) - 1

        for v, cost in graph.get(u, {}).items():
            lv = costfunc(lu, cost)
            # Drop infinite labels and labels already dominated by a solution
            # found at destin.
            if lv >= inf or any(dominate(pl, lv) for pl in plabels[destin]):
                continue
            lv.extend([u, h])
            # Skip lv if a permanent or temporary label of v dominates it.
            if (any(dominate(pl, lv) for pl in plabels[v]) or
                    any(dominate(tl, lv) for tl in tlabels[v])):
                continue
            # Drop every temporary label of v that lv dominates. The indices
            # are collected first and removed from the highest down:
            # remove_tlabels moves the last label into the freed slot, so
            # removing while enumerating would skip that moved label.
            dominated = [j for j, tl in enumerate(tlabels[v]) if dominate(lv, tl)]
            for j in reversed(dominated):
                heap.remove((v, j))
                remove_tlabels(v, j)
            tlabels[v].append(lv)
            heap.add((v, len(tlabels[v]) - 1), lv)





if __name__ == "__main__":
    # Demo: small network whose arcs each carry three cost criteria.
    def addl(label, cost):
        # Sum the three criteria pairwise (the label's trailing p, h are ignored).
        return [a + b for a, b in zip(label[:3], cost[:3])]

    graph = {
        "A": {"B": (1, 1, 0), "D": (1, 0, 0)},
        "B": {"C": (1, 0, 0)},
        "C": {},
        "D": {"B": (1, 0, 0), "C": (1, 0, 1)},
        "E": {},
    }

    paths = martins(graph, "A", "C", costfunc=addl, K=3)
    print(paths)
    log_info("!!!")

